{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/hpc2hdd/home/xzou428/miniconda3/envs/yuh/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Loading checkpoint shards: 100%|██████████| 30/30 [02:50<00:00,  5.68s/it]\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "import evaluate\n",
    "import transformers\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer\n",
    "import torch\n",
    "from torch.nn import CrossEntropyLoss\n",
    "from evaluate import logging\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "import json\n",
    "\n",
    "\n",
    "class Perplexity():\n",
    "\n",
    "    def compute(\n",
    "        self, predictions, model, tokenizer, batch_size: int = 4, add_start_token: bool = True, max_length=None,\n",
    "        output_dir = None\n",
    "    ):\n",
    "\n",
    "        # if batch_size > 1 (which generally leads to padding being required), and\n",
    "        # if there is not an already assigned pad_token, assign an existing\n",
    "        # special token to also be the padding token\n",
    "        if tokenizer.pad_token is None and batch_size > 1:\n",
    "            existing_special_tokens = list(tokenizer.special_tokens_map_extended.values())\n",
    "            # check that the model already has at least one special token defined\n",
    "            assert (\n",
    "                len(existing_special_tokens) > 0\n",
    "            ), \"If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1.\"\n",
    "            # assign one of the special tokens to also be the pad token\n",
    "            tokenizer.add_special_tokens({\"pad_token\": existing_special_tokens[0]})\n",
    "\n",
    "        if add_start_token and max_length:\n",
    "            # leave room for <BOS> token to be added:\n",
    "            assert (\n",
    "                tokenizer.bos_token is not None\n",
    "            ), \"Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False\"\n",
    "            max_tokenized_len = max_length - 1\n",
    "        else:\n",
    "            max_tokenized_len = max_length\n",
    "\n",
    "        encodings = tokenizer(\n",
    "            predictions,\n",
    "            add_special_tokens=False,\n",
    "            padding=True,\n",
    "            truncation=True if max_tokenized_len else False,\n",
    "            max_length=max_tokenized_len,\n",
    "            return_tensors=\"pt\",\n",
    "            return_attention_mask=True,\n",
    "        )\n",
    "\n",
    "        encoded_texts = encodings[\"input_ids\"]\n",
    "        attn_masks = encodings[\"attention_mask\"]\n",
    "\n",
    "        # check that each input is long enough:\n",
    "        if add_start_token:\n",
    "            assert torch.all(torch.ge(attn_masks.sum(1), 1)), \"Each input text must be at least one token long.\"\n",
    "        else:\n",
    "            assert torch.all(\n",
    "                torch.ge(attn_masks.sum(1), 2)\n",
    "            ), \"When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings.\"\n",
    "\n",
    "        ppls = []\n",
    "        loss_fct = CrossEntropyLoss(reduction=\"none\")\n",
    "\n",
    "        for start_index in logging.tqdm(range(0, len(encoded_texts), batch_size)):\n",
    "            end_index = min(start_index + batch_size, len(encoded_texts))\n",
    "            encoded_batch = encoded_texts[start_index:end_index]\n",
    "            attn_mask = attn_masks[start_index:end_index]\n",
    "\n",
    "            if add_start_token:\n",
    "                bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_batch.size(dim=0))\n",
    "                encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch], dim=1)\n",
    "                attn_mask = torch.cat(\n",
    "                    [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64), attn_mask], dim=1\n",
    "                )\n",
    "\n",
    "            labels = encoded_batch\n",
    "\n",
    "            with torch.no_grad():\n",
    "                out_logits = model(encoded_batch, attention_mask=attn_mask).logits\n",
    "\n",
    "            shift_logits = out_logits[..., :-1, :].contiguous()\n",
    "            shift_labels = labels[..., 1:].contiguous()\n",
    "            shift_attention_mask_batch = attn_mask[..., 1:].contiguous()\n",
    "\n",
    "            perplexity_batch = torch.exp(\n",
    "                (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)\n",
    "                / shift_attention_mask_batch.sum(1)\n",
    "            )\n",
    "\n",
    "            ppls += perplexity_batch.tolist()\n",
    "            if output_dir is not None:\n",
    "                with open(output_dir / \"ppls.json\", \"w\") as f:\n",
    "                    json.dump({\"perplexities\": ppls}, f)\n",
    "\n",
    "        return {\"perplexities\": ppls, \"mean_perplexity\": np.mean(ppls)}\n",
    "\n",
    "perplexity = Perplexity()\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "                \"/hpc2hdd/home/xzou428/Yuhao/llama3-70b-instruct\", device_map=\"auto\", torch_dtype=torch.bfloat16, attn_implementation=\"flash_attention_2\"\n",
    "        )\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "            \"/hpc2hdd/home/xzou428/Yuhao/llama3-70b-instruct\"\n",
    "        )\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "tokenizer.pad_token_id = tokenizer.eos_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/hpc2hdd/home/xzou428/miniconda3/envs/yuh/lib/python3.10/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.\n",
      "  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/hpc2hdd/home/xzou428/Yuhao/HiGPT-tune-lightning\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 250/250 [1:07:13<00:00, 16.14s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'perplexities': [2.4998159408569336, 2.9856221675872803, 2.4024786949157715, 5.379189968109131, 4.443045616149902, 3.92936110496521, 2.9446799755096436, 3.7958059310913086, 2.32310152053833, 2.989179849624634, 8.367166519165039, 3.9714832305908203, 4.470235347747803, 2.532318353652954, 2.378948211669922, 3.4641895294189453, 3.7956104278564453, 4.248391151428223, 6.751543998718262, 2.0880579948425293, 1.047723650932312, 4.846902370452881, 3.189065456390381, 3.957723617553711, 3.2210166454315186, 3.8150815963745117, 3.809826135635376, 3.5657310485839844, 3.9706974029541016, 2.909991502761841, 3.0065016746520996, 2.670348882675171, 3.6972227096557617, 3.3581321239471436, 4.191929817199707, 3.8466880321502686, 4.952792644500732, 4.243938446044922, 4.211695671081543, 3.15506649017334, 2.7492849826812744, 4.692835330963135, 3.0737152099609375, 4.440592288970947, 3.530892848968506, 4.226004600524902, 3.0147857666015625, 2.667144298553467, 2.9659416675567627, 4.199448108673096, 2.5692694187164307, 2.9836478233337402, 3.350912570953369, 3.8716747760772705, 4.384750843048096, 3.432697296142578, 4.111363887786865, 3.757571220397949, 2.9869165420532227, 3.7190914154052734, 3.4831509590148926, 2.718311309814453, 4.266922473907471, 4.0299553871154785, 2.9744741916656494, 3.526282787322998, 3.703972101211548, 2.856518507003784, 2.539314031600952, 10.316502571105957, 3.7879374027252197, 3.529261350631714, 4.515599727630615, 6.420227527618408, 4.283083915710449, 2.2194032669067383, 4.98930025100708, 2.6969966888427734, 4.543152332305908, 3.759730815887451, 3.911067247390747, 2.9984242916107178, 4.535149574279785, 3.6181752681732178, 4.8591156005859375, 3.1850883960723877, 3.230882167816162, 6.048998832702637, 3.142192840576172, 4.261178016662598, 3.0015528202056885, 2.735477924346924, 2.723846197128296, 3.368540048599243, 3.844893217086792, 5.0767340660095215, 4.116819381713867, 4.865103721618652, 4.018373012542725, 4.1121954917907715, 4.100131511688232, 3.51717209815979, 2.755133867263794, 3.620913028717041, 4.374685764312744, 3.5372726917266846, 2.9552271366119385, 4.978931427001953, 3.632948875427246, 4.332434177398682, 3.613905668258667, 5.471495151519775, 5.22836446762085, 3.2737507820129395, 5.373469829559326, 6.063565731048584, 1.048418402671814, 3.259190082550049, 5.569326877593994, 3.8920702934265137, 6.198002815246582, 2.852323532104492, 3.6489038467407227, 5.35780668258667, 3.1338019371032715, 3.6744658946990967, 4.771067142486572, 3.6330454349517822, 3.565450429916382, 5.983518123626709, 1.0777326822280884, 3.7394015789031982, 3.821810007095337, 4.101750373840332, 3.598283052444458, 4.660801887512207, 4.617112159729004, 9.325841903686523, 4.101461887359619, 2.377406120300293, 3.200939655303955, 2.5365543365478516, 4.131699562072754, 1.0486876964569092, 3.705716609954834, 4.595509052276611, 3.243587017059326, 5.738552570343018, 6.1697821617126465, 2.654228448867798, 3.643991708755493, 3.7304532527923584, 4.771766185760498, 2.876300573348999, 3.0325777530670166, 5.182375431060791, 3.56845760345459, 3.4192399978637695, 6.366179943084717, 4.2076945304870605, 2.779252052307129, 2.558641195297241, 3.0555269718170166, 3.1571967601776123, 3.205886125564575, 6.336745738983154, 2.4640209674835205, 7.518610954284668, 2.846424102783203, 3.8231396675109863, 4.122981548309326, 2.0860588550567627, 2.6100103855133057, 4.2471232414245605, 3.0355801582336426, 4.023844242095947, 5.968391418457031, 2.4773919582366943, 4.353003978729248, 4.868045806884766, 3.554452419281006, 3.85398530960083, 5.962102890014648, 5.756697654724121, 3.905735731124878, 5.239724636077881, 5.2634382247924805, 1.0446748733520508, 5.055320739746094, 5.147263050079346, 4.0513224601745605, 3.9955294132232666, 3.516756534576416, 4.134942531585693, 4.2272419929504395, 6.466255187988281, 3.3598389625549316, 3.3400719165802, 4.607227325439453, 2.550360679626465, 4.45184850692749, 6.812209606170654, 2.097658395767212, 3.4038188457489014, 3.2298948764801025, 4.636793613433838, 5.776954650878906, 3.8996119499206543, 3.8745603561401367, 2.864074230194092, 5.2852911949157715, 3.4117684364318848, 3.8200671672821045, 4.370829105377197, 4.893868446350098, 4.299007415771484, 1.059079647064209, 6.152494430541992, 3.3304624557495117, 5.921623229980469, 1.1148720979690552, 2.9449946880340576, 3.9610695838928223, 2.391352653503418, 3.1493079662323, 2.5842373371124268, 3.6976630687713623, 4.378880977630615, 3.9818437099456787, 4.736690998077393, 3.2866199016571045, 3.3378186225891113, 2.6978542804718018, 3.9827637672424316, 3.629607915878296, 4.032090663909912, 2.946438789367676, 2.6259000301361084, 4.202783107757568, 3.1678214073181152, 3.3721415996551514, 4.0680060386657715, 3.2613136768341064, 4.357128620147705, 5.98230504989624, 3.327826976776123, 3.7431042194366455, 2.3513412475585938, 6.956093788146973, 4.177375316619873, 4.353437900543213, 2.592355251312256, 3.1734540462493896, 3.649700403213501, 3.673336982727051, 4.134865760803223, 2.0574491024017334, 3.1926000118255615, 3.5099382400512695, 2.696363687515259, 4.211073398590088, 3.31341290473938, 3.719832420349121, 6.694586753845215, 2.9891579151153564, 3.096970796585083, 4.105993270874023, 3.830620288848877, 3.0067665576934814, 2.8758697509765625, 4.126051425933838, 2.799138069152832, 2.7340643405914307, 4.912525653839111, 4.138279438018799, 3.0719990730285645, 3.131398916244507, 2.910111904144287, 3.766773223876953, 3.4446699619293213, 4.072597980499268, 5.992519855499268, 3.8411924839019775, 3.9139771461486816, 2.9384634494781494, 2.9797585010528564, 3.6218671798706055, 5.383536338806152, 3.2639718055725098, 3.8588223457336426, 3.8623950481414795, 4.854751110076904, 3.046217441558838, 4.314635276794434, 4.497443675994873, 3.35622239112854, 2.949951171875, 3.8200249671936035, 5.8194780349731445, 1.066753625869751, 4.007159233093262, 2.3772428035736084, 1.0592585802078247, 4.174618244171143, 4.019759654998779, 3.485914945602417, 1.9573867321014404, 4.326868057250977, 2.637583017349243, 5.954212188720703, 3.3530867099761963, 4.5829291343688965, 1.1054649353027344, 2.848069429397583, 1.359129786491394, 4.031626224517822, 1.0841768980026245, 4.329006195068359, 3.8303043842315674, 3.692148447036743, 3.9946682453155518, 3.1820037364959717, 3.8482067584991455, 2.2990212440490723, 4.210283279418945, 2.7233238220214844, 4.0636162757873535, 2.3887758255004883, 3.948409080505371, 3.640626907348633, 3.8529212474823, 3.863670825958252, 5.363521575927734, 4.369475364685059, 6.406410217285156, 2.9305474758148193, 3.539325475692749, 4.844107151031494, 4.718592643737793, 3.0701394081115723, 3.7373783588409424, 4.176374435424805, 4.240476608276367, 7.164785385131836, 3.3848156929016113, 5.551575660705566, 3.1341142654418945, 3.9152231216430664, 3.280503273010254, 3.3853907585144043, 4.572758674621582, 3.534146785736084, 4.080906391143799, 2.5508742332458496, 2.9766905307769775, 3.262017011642456, 2.943415880203247, 5.115331649780273, 4.401778221130371, 3.491844892501831, 3.1716458797454834, 3.8419723510742188, 3.263272285461426, 3.064347505569458, 2.506976366043091, 3.9556972980499268, 3.607717514038086, 2.5764200687408447, 4.88701868057251, 2.9205894470214844, 3.246717929840088, 2.2552714347839355, 3.3336079120635986, 3.025580883026123, 4.289077281951904, 3.650944948196411, 4.537552356719971, 2.0748507976531982, 1.5643894672393799, 2.028895616531372, 4.075577735900879, 3.091153621673584, 3.7950987815856934, 3.093907356262207, 4.984325408935547, 4.940251350402832, 3.7562508583068848, 4.2696685791015625, 4.267434597015381, 5.108110427856445, 2.190427541732788, 3.0951967239379883, 7.042504787445068, 2.3200790882110596, 3.962284803390503, 4.5884199142456055, 5.876114368438721, 4.279589653015137, 4.223823070526123, 3.5432486534118652, 3.0483462810516357, 2.7412569522857666, 2.860344409942627, 4.929261207580566, 2.8346877098083496, 3.1407155990600586, 3.7099087238311768, 4.627816200256348, 4.222611427307129, 2.564450979232788, 3.2144508361816406, 2.6263327598571777, 2.412083148956299, 5.959141254425049, 4.502986907958984, 5.073246955871582, 4.6628499031066895, 6.383094310760498, 3.8142313957214355, 4.421656608581543, 2.4827933311462402, 3.1327106952667236, 4.705959320068359, 3.046621799468994, 6.0037994384765625, 5.301602840423584, 3.0854573249816895, 4.214744567871094, 3.783942937850952, 4.247945785522461, 2.5481956005096436, 1.8167469501495361, 3.473729372024536, 4.714054107666016, 3.6267454624176025, 4.115428447723389, 2.9184882640838623, 4.711605548858643, 4.567840576171875, 5.857510089874268, 3.6042065620422363, 3.292325496673584, 3.108452796936035, 3.6961605548858643, 6.007719039916992, 2.3917059898376465, 4.269360065460205, 2.2346415519714355, 2.296097993850708, 4.2974724769592285, 3.2656896114349365, 5.294939994812012, 5.415618896484375, 3.3533644676208496, 4.73203706741333, 4.250616073608398, 6.568949222564697, 4.233122825622559, 4.633768558502197, 3.975574016571045, 3.3460710048675537, 4.958654403686523, 4.780910015106201, 4.340291976928711, 4.0094194412231445, 3.5315685272216797, 4.150110244750977, 1.5575271844863892, 3.3129255771636963, 2.23036527633667, 3.424028158187866, 3.186516761779785, 4.28557014465332, 2.67166805267334, 4.1775102615356445, 2.70381236076355, 2.818509817123413, 3.60538649559021, 2.33084774017334, 3.4882144927978516, 2.829195737838745, 3.357194185256958, 4.1775665283203125, 3.593625783920288, 3.5017786026000977, 3.0626416206359863, 3.070733070373535, 1.8282601833343506, 4.302586555480957, 2.514493465423584, 4.698977947235107, 6.408092021942139, 3.532930374145508, 3.2562460899353027, 5.743214130401611, 3.1490414142608643, 3.5376698970794678, 3.9517178535461426, 4.7700018882751465, 2.9389030933380127], 'mean_perplexity': 3.8049056062698363}\n"
     ]
    }
   ],
   "source": [
    "%cd /hpc2hdd/home/xzou428/Yuhao/HiGPT-tune-lightning/\n",
    "from pathlib import Path\n",
    "\n",
    "predictions = []\n",
    "predict_dir = \"checkpoints/higpt-stage2-llama3-new_batch_rw-epoch30-8192-full_finetune/lightning_logs/version_9/predict\"\n",
    "predict_files = list(Path(predict_dir).glob(\"*.txt\"))\n",
    "for file in predict_files:\n",
    "    with open(file, \"r\") as f:\n",
    "        predictions.append(f.read().strip())\n",
    "\n",
    "results = perplexity.compute(predictions=predictions, model=model, tokenizer=tokenizer, batch_size=2,\n",
    "                                output_dir = Path(predict_dir))\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/hpc2hdd/home/xzou428/Yuhao/HiGPT-tune-lightning\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 250/250 [1:08:27<00:00, 16.43s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'perplexities': [7.220343112945557, 8.96116828918457, 12.989233016967773, 10.621763229370117, 4.5849409103393555, 11.24061393737793, 1.1187939643859863, 6.282163619995117, 4.431479454040527, 5.874669551849365, 7.204396724700928, 5.251492023468018, 8.113046646118164, 4.6291680335998535, 7.889655590057373, 8.999985694885254, 5.765640735626221, 5.632110595703125, 12.575323104858398, 4.418312072753906, 9.290485382080078, 10.942720413208008, 6.392186641693115, 12.512866973876953, 9.693538665771484, 9.394514083862305, 8.571351051330566, 8.944293022155762, 7.674201011657715, 5.082426071166992, 4.278109073638916, 5.771695613861084, 10.285709381103516, 3.845829725265503, 13.966651916503906, 8.982160568237305, 7.953607082366943, 10.174062728881836, 7.734488487243652, 9.132355690002441, 1.0851534605026245, 9.443883895874023, 7.021188735961914, 10.995126724243164, 9.3703031539917, 6.958315849304199, 7.682322025299072, 5.866414546966553, 9.120397567749023, 8.248562812805176, 4.903852462768555, 11.392155647277832, 9.49405574798584, 11.316476821899414, 19.451744079589844, 6.9229278564453125, 12.64656925201416, 6.4150872230529785, 8.24782943725586, 8.002391815185547, 5.8502326011657715, 4.1947736740112305, 7.609427452087402, 3.5787806510925293, 7.035768032073975, 7.896481037139893, 7.550848007202148, 5.872852325439453, 4.217881202697754, 9.962814331054688, 9.17367935180664, 7.638707160949707, 8.688178062438965, 4.355037212371826, 13.634010314941406, 6.235501289367676, 8.066516876220703, 4.078326225280762, 9.71679401397705, 14.559103012084961, 12.22741413116455, 5.265911102294922, 13.442512512207031, 7.251509189605713, 5.035361289978027, 7.570530414581299, 8.253658294677734, 15.71534252166748, 5.03449821472168, 7.56925106048584, 6.439436435699463, 3.560218334197998, 7.2441792488098145, 6.160134792327881, 7.4211602210998535, 4.697006702423096, 9.276329040527344, 15.749652862548828, 9.126855850219727, 10.86497974395752, 8.89433479309082, 9.660564422607422, 6.607327938079834, 6.109801769256592, 10.741618156433105, 6.628200531005859, 7.171274185180664, 8.43789291381836, 7.61118745803833, 11.833160400390625, 4.503620147705078, 12.77609920501709, 10.618331909179688, 7.197545051574707, 11.016772270202637, 3.5174190998077393, 8.35169792175293, 6.892431735992432, 6.624718189239502, 6.963745594024658, 17.852022171020508, 3.018238067626953, 9.243905067443848, 3.942542552947998, 7.599405288696289, 8.596990585327148, 7.727963447570801, 5.031185150146484, 4.7419304847717285, 12.599516868591309, 7.149275302886963, 6.931704998016357, 6.939282417297363, 10.119980812072754, 7.262596607208252, 3.110896587371826, 9.149092674255371, 12.27851390838623, 1.1196651458740234, 6.186306476593018, 10.388976097106934, 5.931013584136963, 6.802693843841553, 3.672261953353882, 3.0059285163879395, 16.050884246826172, 7.533730506896973, 11.415706634521484, 9.699600219726562, 3.74501895904541, 1.1301047801971436, 4.894139289855957, 13.44692325592041, 5.381185054779053, 4.677361488342285, 6.731645107269287, 6.703747272491455, 8.357439041137695, 11.068706512451172, 12.2701416015625, 8.499791145324707, 1.1916362047195435, 4.688991546630859, 3.5896248817443848, 1.1452391147613525, 5.439231872558594, 8.432036399841309, 14.158225059509277, 1.072849988937378, 4.284598350524902, 7.548576354980469, 1.0746533870697021, 6.8942108154296875, 8.472738265991211, 5.157893180847168, 9.168868064880371, 12.57855224609375, 6.726876258850098, 8.923368453979492, 10.231858253479004, 5.3236985206604, 6.439993858337402, 7.80320405960083, 12.051140785217285, 10.582963943481445, 11.359773635864258, 6.424800395965576, 8.17261028289795, 10.90157413482666, 5.746174335479736, 9.15550422668457, 9.571870803833008, 5.548766136169434, 7.569949150085449, 6.359655857086182, 17.846878051757812, 8.036689758300781, 6.193211555480957, 7.188570499420166, 5.709229946136475, 7.31947135925293, 14.465582847595215, 8.428372383117676, 9.922321319580078, 7.509680271148682, 5.906876564025879, 10.65238094329834, 5.699851036071777, 8.585233688354492, 3.5278782844543457, 12.197339057922363, 6.565845966339111, 9.886456489562988, 2.3638486862182617, 13.902608871459961, 9.749185562133789, 7.264040946960449, 6.163508415222168, 10.603315353393555, 9.863275527954102, 13.064468383789062, 7.756284236907959, 7.4851250648498535, 3.8213508129119873, 13.399558067321777, 3.0382399559020996, 7.187049388885498, 11.86284351348877, 8.502946853637695, 8.574925422668457, 8.412863731384277, 5.036327838897705, 10.819705963134766, 8.64692211151123, 6.919644832611084, 7.937554359436035, 5.44537878036499, 1.0613434314727783, 7.1330437660217285, 10.071870803833008, 5.686461448669434, 5.471285820007324, 7.8208160400390625, 5.048673629760742, 11.589432716369629, 8.232829093933105, 9.158036231994629, 1.1788792610168457, 1.077649712562561, 8.10289192199707, 4.585588455200195, 9.002124786376953, 6.267740726470947, 4.161913871765137, 6.620725154876709, 14.308003425598145, 3.315250873565674, 4.411523342132568, 7.886943340301514, 7.05299186706543, 5.398303031921387, 7.905517101287842, 5.551706790924072, 10.043386459350586, 8.768553733825684, 2.8443737030029297, 14.773366928100586, 6.622905254364014, 10.352378845214844, 9.957834243774414, 3.354011297225952, 5.901095390319824, 6.8740153312683105, 1.2803819179534912, 9.318229675292969, 10.565739631652832, 10.265250205993652, 8.616751670837402, 7.6539716720581055, 5.58084774017334, 3.432999849319458, 16.084611892700195, 9.179757118225098, 5.075444221496582, 6.671141147613525, 9.25600814819336, 7.912884712219238, 9.213436126708984, 4.6840715408325195, 3.844160556793213, 7.096933364868164, 9.398102760314941, 7.295798301696777, 10.756917953491211, 7.250435829162598, 9.845826148986816, 4.78450870513916, 9.741100311279297, 12.252110481262207, 7.319858074188232, 7.120351791381836, 3.895526885986328, 8.905900001525879, 7.910192966461182, 14.293567657470703, 7.489871978759766, 5.011918067932129, 5.505130767822266, 2.761073350906372, 9.360201835632324, 6.980151653289795, 5.901315689086914, 16.986019134521484, 4.305767059326172, 5.621553897857666, 5.981314659118652, 10.398664474487305, 9.432245254516602, 6.666719913482666, 8.02414321899414, 1.2547578811645508, 11.206165313720703, 8.056528091430664, 5.68672513961792, 1.0809593200683594, 7.599495887756348, 8.140213012695312, 4.131028175354004, 9.198101043701172, 11.209614753723145, 9.144412994384766, 1.1201035976409912, 7.332896709442139, 8.845931053161621, 6.8972015380859375, 3.268385171890259, 3.4898288249969482, 5.064342021942139, 8.95376205444336, 8.162796020507812, 6.904360294342041, 15.086037635803223, 11.995706558227539, 12.244240760803223, 6.719570159912109, 9.748353958129883, 7.373607158660889, 7.568200588226318, 7.0632643699646, 6.794408798217773, 7.4992852210998535, 7.588514804840088, 7.6789231300354, 7.344943523406982, 5.670265197753906, 6.245641708374023, 8.63520336151123, 7.086880683898926, 4.683644771575928, 13.2805757522583, 8.841081619262695, 7.114266872406006, 7.771805763244629, 6.436595439910889, 3.2363266944885254, 12.903082847595215, 6.693779945373535, 9.35282039642334, 5.935966968536377, 7.145452976226807, 5.26027250289917, 7.565586566925049, 4.630558490753174, 3.9179294109344482, 7.0920257568359375, 5.14495849609375, 11.116034507751465, 10.144906997680664, 10.180548667907715, 6.071567058563232, 5.912420272827148, 5.957441806793213, 4.414993762969971, 11.589162826538086, 12.148540496826172, 8.697221755981445, 6.586933612823486, 3.9885950088500977, 7.1154913902282715, 1.0880879163742065, 7.138180732727051, 5.78258752822876, 10.060541152954102, 9.022461891174316, 4.738103866577148, 4.706284523010254, 16.1583251953125, 4.837371826171875, 10.314773559570312, 7.905353546142578, 8.856805801391602, 11.439857482910156, 12.256121635437012, 9.66690731048584, 6.171576023101807, 6.195257186889648, 10.330521583557129, 7.946025848388672, 11.008023262023926, 5.640432357788086, 5.856267929077148, 5.384231090545654, 7.1944122314453125, 4.253411293029785, 6.338245391845703, 3.6887013912200928, 9.069557189941406, 8.764275550842285, 7.1462297439575195, 8.622961044311523, 5.245873928070068, 3.986828088760376, 3.1614303588867188, 7.526843547821045, 9.269258499145508, 7.412954807281494, 7.505431175231934, 12.329818725585938, 11.010259628295898, 8.507984161376953, 4.833631992340088, 10.997504234313965, 4.289799213409424, 10.621211051940918, 9.23244571685791, 8.729961395263672, 4.520023822784424, 11.634332656860352, 8.38453197479248, 6.534638404846191, 6.196821212768555, 8.198404312133789, 4.343478679656982, 5.13840389251709, 11.474964141845703, 6.448729515075684, 17.000093460083008, 2.4251976013183594, 4.753681182861328, 4.813736438751221, 1.1324104070663452, 7.535841464996338, 8.976275444030762, 17.436830520629883, 10.534748077392578, 6.961845397949219, 14.149333000183105, 4.133641719818115, 8.129706382751465, 5.803528785705566, 9.58532428741455, 5.74629020690918, 7.398850440979004, 11.561851501464844, 3.9612834453582764, 3.9940848350524902, 8.434652328491211, 3.2969889640808105, 1.9842207431793213, 5.564487457275391, 6.185572624206543, 9.545723915100098, 6.764242649078369, 4.507950305938721, 10.386082649230957, 5.986346244812012, 5.841550827026367, 7.996969699859619, 8.169111251831055, 10.925585746765137, 8.427659034729004, 4.7360100746154785, 9.405625343322754, 4.25331974029541, 14.376008033752441, 3.1820473670959473, 5.074550151824951, 2.889564275741577, 8.261034965515137, 5.877311706542969, 8.993685722351074, 6.598845481872559, 7.6168389320373535, 10.149747848510742, 17.840648651123047, 8.566028594970703, 4.947693824768066, 5.49942684173584, 8.633851051330566, 6.0543437004089355], 'mean_perplexity': 7.6818472802639}\n"
     ]
    }
   ],
   "source": [
    "%cd /hpc2hdd/home/xzou428/Yuhao/HiGPT-tune-lightning/\n",
    "from pathlib import Path\n",
    "\n",
    "predictions = []\n",
    "predict_dir = \"checkpoints/vanilla-llama3-nlp_rw-epoch30-8192-full_finetune/lightning_logs/version_1/predict\"\n",
    "predict_files = list(Path(predict_dir).glob(\"*.txt\"))\n",
    "for file in predict_files:\n",
    "    with open(file, \"r\") as f:\n",
    "        predictions.append(f.read().strip())\n",
    "\n",
    "results = perplexity.compute(predictions=predictions, model=model, tokenizer=tokenizer, batch_size=2,\n",
    "                                output_dir = Path(predict_dir))\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/hpc2hdd/home/xzou428/miniconda3/envs/yuh/lib/python3.10/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.\n",
      "  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/hpc2hdd/home/xzou428/Yuhao/HiGPT-tune-lightning\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/256 [00:00<?, ?it/s]We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n",
      "100%|██████████| 256/256 [15:38<00:00,  3.67s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'perplexities': [6.662538528442383, 6.703956127166748, 3.888385772705078, 3.860496759414673, 5.755697727203369, 6.0298566818237305, 5.871049404144287, 5.664289951324463, 6.551420211791992, 5.2591328620910645, 7.692512512207031, 5.2577714920043945, 5.179086685180664, 5.586071491241455, 5.870924472808838, 4.222850799560547, 4.087632656097412, 7.865680694580078, 4.785553455352783, 5.789429664611816, 4.1340203285217285, 3.2448477745056152, 5.420875549316406, 6.158207416534424, 4.479437351226807, 2.862560749053955, 5.847195148468018, 6.166591644287109, 7.180481910705566, 6.164158821105957, 3.264415740966797, 3.9993505477905273, 4.329326629638672, 3.422545909881592, 4.677935600280762, 7.166839599609375, 5.223176956176758, 5.143620491027832, 8.228843688964844, 6.073873043060303, 5.15476131439209, 8.705445289611816, 4.9601826667785645, 5.539163589477539, 8.62806510925293, 8.945265769958496, 3.7210423946380615, 4.114291191101074, 3.8274147510528564, 4.724334716796875, 3.085158109664917, 4.659389495849609, 4.736699104309082, 3.8920302391052246, 7.095478534698486, 5.586363792419434, 4.431929111480713, 4.35570764541626, 5.924088001251221, 7.646994590759277, 7.558742523193359, 5.415041923522949, 5.391783237457275, 5.710630893707275, 6.553561210632324, 5.08038330078125, 24202.970703125, 4.3625993728637695, 5.38886022567749, 4.323681831359863, 3.4986512660980225, 6.741409778594971, 4.378742694854736, 6.967719554901123, 3.647427797317505, 7.260585784912109, 5.412897109985352, 5.22982931137085, 3.1289052963256836, 4.312517166137695, 5.7000603675842285, 3.4615960121154785, 5.027934551239014, 4.197242259979248, 6.184034824371338, 3.6530447006225586, 3.6694018840789795, 5.2166972160339355, 5.269437789916992, 5.408716678619385, 5.507139205932617, 7.085141181945801, 5.628697872161865, 5.584565162658691, 3.9229159355163574, 3.926344394683838, 3.7954013347625732, 3.792998790740967, 5.297114372253418, 4.909903049468994, 5.071468830108643, 4.626134395599365, 5.517248153686523, 5.214006423950195, 6.202589988708496, 4.715069770812988, 6.765402317047119, 5.9384541511535645, 3.7338831424713135, 4.135748386383057, 8.875946998596191, 4.352761268615723, 5.4750657081604, 3.5181949138641357, 5.726487159729004, 3.579441547393799, 4.095937728881836, 8.571152687072754, 6.197169303894043, 6.460610866546631, 3.4665918350219727, 5.4008588790893555, 4.089663982391357, 6.725459575653076, 6.664960861206055, 5.580117702484131, 6.606041431427002, 4.74251651763916, 4.04836368560791, 4.629095077514648, 6.462523460388184, 3.3792965412139893, 6.672108173370361, 7.51210880279541, 5.839513301849365, 4.689689636230469, 7.382263660430908, 3.717721462249756, 4.81402587890625, 4.136185169219971, 3.535712718963623, 7.222692489624023, 4.262279033660889, 5.512328624725342, 5.026167392730713, 4.41909122467041, 4.984361171722412, 4.795871734619141, 4.585757255554199, 8.283149719238281, 3.261279821395874, 7.348974704742432, 4.8275628089904785, 4.444346904754639, 3.97112774848938, 6.875731468200684, 9.139738082885742, 5.706970691680908, 3.1943423748016357, 4.288197994232178, 7.07441520690918, 5.951887607574463, 6.902925491333008, 5.358438968658447, 6.4210052490234375, 4.576847076416016, 2.962554931640625, 7.353837966918945, 3.962972402572632, 3.6312992572784424, 6.671178340911865, 3.8744795322418213, 4.415317058563232, 6.784434795379639, 3.3927018642425537, 5.812861919403076, 5.105876922607422, 4.232609748840332, 6.007431983947754, 4.729918479919434, 4.570559501647949, 5.445316791534424, 5.902717113494873, 4.122618675231934, 3.797703742980957, 4.940896987915039, 3.3522744178771973, 6.149424076080322, 5.275523662567139, 3.994285821914673, 3.411315441131592, 6.909533977508545, 7.8282856941223145, 4.22587251663208, 4.482039928436279, 5.946120738983154, 5.450639724731445, 3.973278522491455, 7.923694610595703, 4.674211025238037, 7.704292297363281, 6.332148551940918, 3.8341851234436035, 4.7883782386779785, 5.201913356781006, 4.4317474365234375, 4.404654026031494, 2.8935000896453857, 6.955092430114746, 6.007475852966309, 5.389462471008301, 6.186148643493652, 3.4111194610595703, 5.937147617340088, 8.46185302734375, 3.8000404834747314, 3.078078269958496, 5.010270595550537, 4.9645562171936035, 4.4533867835998535, 3.955784559249878, 4.237294673919678, 4.083903789520264, 8.830854415893555, 5.014210224151611, 5.593417644500732, 7.272493839263916, 7.347316741943359, 4.14654541015625, 4.201357841491699, 3.8060991764068604, 5.787951469421387, 5.118149280548096, 4.85090446472168, 3.614924669265747, 4.4286417961120605, 4.764253616333008, 5.760751724243164, 4.88280725479126, 6.386399745941162, 4.705770492553711, 5.541849613189697, 4.261627674102783, 4.6020917892456055, 3.3283965587615967, 5.680573463439941, 5.776752948760986, 5.466312408447266, 4.26395320892334, 3.5690412521362305, 3.292269229888916, 3.3537778854370117, 6.56951379776001, 5.386791706085205, 3.9389209747314453, 5.717684268951416, 3.729127883911133, 4.548696517944336, 5.606823444366455, 6.059210300445557, 3.210577964782715, 5.949158191680908, 4.113687992095947, 10.398811340332031, 3.828125238418579, 4.203240871429443, 4.407334327697754, 5.516098976135254, 5.635103225708008, 5.463208198547363, 6.022660732269287, 5.102939605712891, 5.384347915649414, 5.0700907707214355, 5.772679805755615, 7.786012172698975, 6.466671466827393, 4.442397594451904, 5.527129650115967, 5.924559116363525, 7.030980110168457, 6.96299934387207, 7.200747013092041, 3.1799159049987793, 3.906409978866577, 4.60396671295166, 4.604850769042969, 3.038374662399292, 5.622485637664795, 5.152244567871094, 4.279562473297119, 5.348459720611572, 6.248631954193115, 5.729779243469238, 3.8594484329223633, 5.686919689178467, 5.974486827850342, 3.869194746017456, 4.560394287109375, 4.35051965713501, 4.310585021972656, 2.8108153343200684, 4.915740966796875, 3.8159689903259277, 5.410823822021484, 3.8215479850769043, 4.189111709594727, 5.1165266036987305, 4.872072219848633, 4.399116039276123, 3.9895741939544678, 4.0189104080200195, 3.6989574432373047, 4.31076192855835, 5.298948287963867, 3.9632375240325928, 5.997828483581543, 4.0958943367004395, 4.107686519622803, 6.265992641448975, 4.765856742858887, 3.651825428009033, 5.823949337005615, 6.117764949798584, 5.148277759552002, 6.747464656829834, 4.339146137237549, 6.114123821258545, 6.197878360748291, 6.051358222961426, 4.186939716339111, 4.970992088317871, 4.533382415771484, 4.921335220336914, 6.035527229309082, 4.745070457458496, 4.52655553817749, 4.936126232147217, 5.203882217407227, 5.427097320556641, 3.5444090366363525, 5.765077114105225, 4.8613386154174805, 6.558455944061279, 4.534721374511719, 6.137139797210693, 5.296500205993652, 5.192507743835449, 3.8738017082214355, 6.418181896209717, 3.8870232105255127, 5.180581092834473, 6.680762767791748, 4.294639587402344, 4.258012294769287, 3.7183191776275635, 5.1748127937316895, 6.545309066772461, 4.324765205383301, 4.862061500549316, 4.008600234985352, 4.335699081420898, 5.111899375915527, 4.66564416885376, 3.4809393882751465, 4.718833923339844, 3.657006025314331, 4.356086730957031, 3.8542325496673584, 3.7438838481903076, 8.534212112426758, 3.0858120918273926, 4.123644828796387, 6.021222114562988, 5.1715803146362305, 4.017935276031494, 6.97493314743042, 2.9816858768463135, 7.407923221588135, 6.504552364349365, 4.91812801361084, 4.8854546546936035, 7.138632297515869, 4.3353800773620605, 3.718595504760742, 5.792558193206787, 4.084202766418457, 4.839554309844971, 6.38741397857666, 5.588606834411621, 4.314589023590088, 5.828339576721191, 5.161320209503174, 4.016021251678467, 3.5616965293884277, 4.139617443084717, 6.912775039672852, 4.549631118774414, 4.677456855773926, 6.8317952156066895, 5.015533447265625, 5.304661750793457, 24202.970703125, 4.622883319854736, 4.430398464202881, 4.537962436676025, 3.6743156909942627, 5.423048496246338, 3.6111414432525635, 5.905645847320557, 6.9877729415893555, 5.518228530883789, 5.726145267486572, 6.1621575355529785, 3.2489309310913086, 5.3847150802612305, 4.83175802230835, 4.385636806488037, 6.610490322113037, 6.083458423614502, 4.432558536529541, 8.062433242797852, 4.8544602394104, 5.310300350189209, 4.231846809387207, 4.342684268951416, 5.564296722412109, 4.424440860748291, 4.532093048095703, 5.155887126922607, 3.684697389602661, 4.710153579711914, 6.5139617919921875, 7.01156759262085, 6.481410980224609, 6.300573348999023, 7.838099956512451, 4.137085914611816, 6.216036319732666, 3.944736957550049, 3.6403181552886963, 5.976691722869873, 7.778058052062988, 4.572110176086426, 5.164376735687256, 4.269871234893799, 5.084219455718994, 3.6899306774139404, 4.847320556640625, 5.465286731719971, 4.998935222625732, 5.990747451782227, 4.668854713439941, 6.9836883544921875, 6.158501625061035, 5.374502658843994, 4.198220729827881, 10.078145027160645, 5.134241104125977, 4.147216796875, 4.064144611358643, 3.626920223236084, 4.814506530761719, 8.226042747497559, 3.8112165927886963, 4.380001544952393, 3.9852354526519775, 4.493175029754639, 4.788084983825684, 3.9440057277679443, 6.0517964363098145, 5.4362664222717285, 4.260653972625732, 4.7215471267700195, 4.069214344024658, 5.30673360824585, 4.047527313232422, 3.9179513454437256, 5.935842514038086, 3.777833938598633, 5.743401527404785, 6.7486090660095215, 7.218299388885498, 5.386877536773682, 7.001348495483398, 4.55291223526001, 5.205036640167236, 7.967484951019287, 5.853339672088623, 5.747446537017822, 6.275759220123291, 6.374061584472656, 5.981497764587402, 8.006004333496094, 6.06664514541626, 4.685530662536621, 7.09424352645874, 3.6195623874664307, 5.32649040222168, 4.2544660568237305, 6.229556560516357, 4.5400004386901855, 5.6213860511779785, 3.153139591217041, 4.500868320465088, 5.480484962463379, 6.3782806396484375, 3.586812734603882, 4.8326416015625, 3.857355833053589, 7.536346435546875, 4.536384105682373], 'mean_perplexity': 99.71032925974578}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "%cd /hpc2hdd/home/xzou428/Yuhao/HiGPT-tune-lightning/\n",
    "from pathlib import Path\n",
    "\n",
    "predictions = json.load(open(\"inference/NLP_RW/eval_res_gemini.json\"))\n",
    "predictions = [list(entry.values())[0] for entry in predictions]\n",
    "\n",
    "results = perplexity.compute(predictions=predictions, model=model, tokenizer=tokenizer, batch_size=2,\n",
    "                                output_dir = None)\n",
    "print(results)\n",
    "\n",
    "import numpy\n",
    "\n",
    "ppl = np.asarray(results[\"perplexities\"])\n",
    "ppl = ppl[ppl < 1000]\n",
    "\n",
    "print(np.mean(ppl))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/hpc2hdd/home/xzou428/miniconda3/envs/yuh/lib/python3.10/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.\n",
      "  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/hpc2hdd/home/xzou428/Yuhao/HiGPT-tune-lightning\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/250 [00:00<?, ?it/s]We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n",
      "100%|██████████| 250/250 [09:57<00:00,  2.39s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'perplexities': [9.312560081481934, 9.221843719482422, 8.040806770324707, 6.024981498718262, 7.501518726348877, 9.539844512939453, 11.127911567687988, 6.607606410980225, 8.762290954589844, 8.258460998535156, 8.11427116394043, 8.171050071716309, 6.917275905609131, 7.734862804412842, 7.986056327819824, 5.662349700927734, 6.705543518066406, 9.045670509338379, 6.573861122131348, 7.82411527633667, 5.186086654663086, 3.861078977584839, 7.401772499084473, 6.681850433349609, 8.068716049194336, 4.556868076324463, 9.046528816223145, 7.144874572753906, 7.762122631072998, 12.060916900634766, 5.450791835784912, 5.359674453735352, 7.533213138580322, 5.097501277923584, 6.10175085067749, 9.89030933380127, 4.886504650115967, 6.456973552703857, 7.6829447746276855, 6.95530366897583, 6.6802592277526855, 9.050185203552246, 5.664628982543945, 10.29240894317627, 9.781145095825195, 10.914595603942871, 5.041409492492676, 5.9026265144348145, 5.638894081115723, 6.4446702003479, 4.258349418640137, 8.110431671142578, 8.156381607055664, 6.767422676086426, 6.905084609985352, 6.469115734100342, 8.686117172241211, 6.164018630981445, 6.756545543670654, 8.679332733154297, 6.5374979972839355, 6.696555137634277, 4.707193374633789, 10.302373886108398, 10.461716651916504, 8.30259895324707, 5.749345302581787, 9.619105339050293, 7.038620471954346, 5.278081893920898, 7.502203464508057, 6.259556293487549, 6.87300443649292, 6.37986421585083, 5.568101406097412, 5.8669562339782715, 7.426413536071777, 9.122325897216797, 8.673739433288574, 9.815473556518555, 8.621908187866211, 5.830793857574463, 4.90763521194458, 6.7868123054504395, 6.93375825881958, 6.465750694274902, 8.453883171081543, 9.810340881347656, 7.226233959197998, 8.434913635253906, 8.100589752197266, 8.971686363220215, 10.116687774658203, 7.32738733291626, 13.137871742248535, 6.051090240478516, 5.0903544425964355, 6.417441368103027, 8.290990829467773, 5.268424034118652, 7.7625908851623535, 5.714381694793701, 8.4337739944458, 4.437535285949707, 6.562226295471191, 10.80435562133789, 7.422957897186279, 9.49061107635498, 5.016573429107666, 7.334643363952637, 5.806483745574951, 7.2157673835754395, 9.288261413574219, 7.971037864685059, 12.030354499816895, 9.717634201049805, 4.378612995147705, 7.632652282714844, 6.145044803619385, 6.611827373504639, 7.432644844055176, 8.539335250854492, 6.320813179016113, 5.2603559494018555, 9.364056587219238, 7.541131019592285, 6.067401885986328, 5.717275619506836, 6.6764817237854, 6.339641571044922, 6.634655952453613, 8.410008430480957, 7.411853790283203, 7.2103705406188965, 6.977002143859863, 6.802545547485352, 5.35194206237793, 11.543116569519043, 4.803343772888184, 11.141949653625488, 6.912421703338623, 6.634244441986084, 5.602391242980957, 6.936130046844482, 9.486367225646973, 5.251814365386963, 6.8840413093566895, 6.403943061828613, 6.979999542236328, 10.267790794372559, 10.446439743041992, 7.906921863555908, 10.211543083190918, 8.264200210571289, 6.627346515655518, 9.921611785888672, 6.237675666809082, 6.006741046905518, 9.430275917053223, 6.255250930786133, 6.083273410797119, 7.444585800170898, 5.252116680145264, 6.816104888916016, 7.05842924118042, 7.824029445648193, 7.03148889541626, 6.37675666809082, 7.395875453948975, 8.734901428222656, 8.198736190795898, 6.350574016571045, 5.801605701446533, 8.02487850189209, 8.249451637268066, 6.611176013946533, 6.995959281921387, 4.579314231872559, 5.530404090881348, 7.560732364654541, 6.348412036895752, 5.660098075866699, 6.975684642791748, 10.730414390563965, 9.026705741882324, 5.67408561706543, 8.434327125549316, 6.275071620941162, 10.335020065307617, 7.935349941253662, 4.688920021057129, 6.650112628936768, 7.029598712921143, 5.942137718200684, 6.077907085418701, 3.909417152404785, 11.017943382263184, 7.101601600646973, 7.294571876525879, 7.33811092376709, 5.4711174964904785, 8.182912826538086, 9.589523315429688, 4.7575154304504395, 6.335957050323486, 7.562441349029541, 8.765088081359863, 7.297943592071533, 7.103304386138916, 9.403589248657227, 6.465813636779785, 7.853869438171387, 7.305665016174316, 8.211438179016113, 7.855367183685303, 8.711992263793945, 6.3360795974731445, 5.411841869354248, 5.730742454528809, 6.366022109985352, 8.389141082763672, 6.18233585357666, 6.386282444000244, 7.307461261749268, 6.353298187255859, 8.092055320739746, 8.017060279846191, 6.8331708908081055, 6.945536136627197, 7.847832679748535, 7.851160049438477, 6.5636444091796875, 4.7340617179870605, 7.316767692565918, 8.532465934753418, 7.027824878692627, 5.834460258483887, 4.715865135192871, 6.738222122192383, 5.86079216003418, 9.48674488067627, 6.327690601348877, 5.736347675323486, 7.439435005187988, 7.689186096191406, 8.113432884216309, 8.463281631469727, 6.305573463439941, 5.377652645111084, 8.439959526062012, 5.607684135437012, 12.707243919372559, 6.169212341308594, 8.968123435974121, 5.939140796661377, 6.144952774047852, 8.323871612548828, 6.453080654144287, 6.810405731201172, 7.343726634979248, 9.842134475708008, 6.460184097290039, 6.530318260192871, 10.261382102966309, 7.558819770812988, 6.881699562072754, 6.412714004516602, 7.462529182434082, 7.855477809906006, 8.159554481506348, 8.563886642456055, 4.996416091918945, 6.8183064460754395, 7.7314372062683105, 5.1806745529174805, 4.658047676086426, 7.295983791351318, 9.480239868164062, 6.5195631980896, 6.549354553222656, 8.178970336914062, 5.369923114776611, 5.01273775100708, 5.415328025817871, 6.847012042999268, 5.98974084854126, 6.289835453033447, 6.016341686248779, 5.305485248565674, 5.145662784576416, 7.327091217041016, 6.513487339019775, 5.854913234710693, 8.336894989013672, 6.055881500244141, 8.294069290161133, 6.896965503692627, 7.916579723358154, 6.767147541046143, 5.442023754119873, 6.625270366668701, 6.4903154373168945, 6.2707953453063965, 5.745615482330322, 9.185056686401367, 7.496598243713379, 6.689723491668701, 7.536632061004639, 7.2466206550598145, 6.689988136291504, 8.730996131896973, 7.309865951538086, 7.520335674285889, 8.429887771606445, 5.214126110076904, 7.408959865570068, 7.69164514541626, 8.868112564086914, 6.062156677246094, 6.9776291847229, 5.822432041168213, 7.195827484130859, 8.635066032409668, 6.686856269836426, 7.806834697723389, 7.201118469238281, 7.1880083084106445, 8.98100471496582, 5.021627902984619, 6.1562299728393555, 7.378082752227783, 8.453081130981445, 8.913155555725098, 6.7438788414001465, 6.952868461608887, 6.179518222808838, 6.082980632781982, 9.000972747802734, 8.28164291381836, 7.720267295837402, 8.00458812713623, 4.71618127822876, 7.75130033493042, 6.0116047859191895, 8.410880088806152, 6.548486709594727, 8.079642295837402, 6.972034931182861, 7.471640586853027, 5.632212162017822, 8.574850082397461, 8.10395622253418, 7.1430768966674805, 7.3466010093688965, 6.368697643280029, 6.484634876251221, 6.902060031890869, 5.309472560882568, 11.500052452087402, 4.548668384552002, 6.402443885803223, 6.648556709289551, 5.073703289031982, 9.100869178771973, 8.00327205657959, 6.736221790313721, 7.088779926300049, 6.74337100982666, 5.689087390899658, 5.424946308135986, 7.563905715942383, 5.9351348876953125, 4.984737396240234, 9.233253479003906, 7.042651653289795, 7.700957298278809, 6.567343235015869, 7.268996715545654, 5.619235992431641, 9.100887298583984, 13.097640037536621, 5.189635753631592, 5.5987091064453125, 5.893795967102051, 9.053448677062988, 5.2161383628845215, 4.860011577606201, 8.852519989013672, 7.19748592376709, 7.166595458984375, 5.631178855895996, 11.448596954345703, 7.970633029937744, 5.529946327209473, 8.359868049621582, 8.579291343688965, 6.077162265777588, 7.044449329376221, 8.275721549987793, 7.306338310241699, 7.933928966522217, 6.7828521728515625, 4.908514976501465, 7.479051113128662, 7.626996994018555, 7.384071350097656, 8.51095199584961, 11.381722450256348, 8.67045783996582, 8.791007995605469, 7.215358734130859, 9.872136116027832, 5.705933570861816, 6.360829830169678, 6.26731538772583, 5.315408229827881, 7.295897483825684, 7.546451091766357, 5.590457916259766, 7.559279441833496, 7.099909782409668, 9.944012641906738, 8.885107040405273, 7.78354549407959, 10.371171951293945, 6.963494777679443, 8.553906440734863, 5.214130401611328, 6.278453826904297, 7.3765997886657715, 15.479964256286621, 6.853878498077393, 7.336277008056641, 7.146265506744385, 8.571422576904297, 7.715205192565918, 6.8458123207092285, 6.1180644035339355, 6.0257744789123535, 6.380343914031982, 5.730459690093994, 8.815149307250977, 6.944349765777588, 5.264153480529785, 7.110106468200684, 11.06009578704834, 6.990967273712158, 5.159672737121582, 7.582636833190918, 4.917535781860352, 7.191940784454346, 8.662249565124512, 7.259420871734619, 6.786746025085449, 5.533170223236084, 7.579611301422119, 5.970799446105957, 5.710256099700928, 8.78826904296875, 6.243644714355469, 6.500482082366943, 5.739875316619873, 7.33734655380249, 6.654483795166016, 7.894730091094971, 5.119368553161621, 7.7143940925598145, 4.864841938018799, 7.439601898193359, 7.5136919021606445, 9.664363861083984, 5.714620590209961, 8.929352760314941, 6.169558525085449, 10.077095031738281, 8.239130973815918, 7.910396575927734, 6.299825668334961, 10.044600486755371, 9.698083877563477, 6.100404739379883, 10.99542236328125, 6.471537113189697, 8.003249168395996, 9.993974685668945, 6.453426837921143, 9.564780235290527, 5.708940505981445, 9.306303024291992, 4.536437034606934, 7.7561140060424805, 5.1568427085876465, 7.142203330993652, 9.284577369689941, 7.056450843811035, 4.792610168457031, 6.762404441833496, 7.333547115325928, 10.054882049560547, 5.7187089920043945], 'mean_perplexity': 7.2773297810554505}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "%cd /hpc2hdd/home/xzou428/Yuhao/HiGPT-tune-lightning/\n",
    "from pathlib import Path\n",
    "\n",
    "predictions = json.load(open(\"inference/NLP_RW/gpt4o-mini/eval_res_gpt4o-mini.json\"))\n",
    "predictions = [list(entry.values())[0] for entry in predictions]\n",
    "\n",
    "results = perplexity.compute(predictions=predictions, model=model, tokenizer=tokenizer, batch_size=2,\n",
    "                                output_dir = None)\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/hpc2hdd/home/xzou428/Yuhao/HiGPT-tune-lightning\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 250/250 [14:44<00:00,  3.54s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'perplexities': [6.760364532470703, 7.253321647644043, 5.072598934173584, 4.263315677642822, 5.974827289581299, 7.646799564361572, 10.481669425964355, 5.767916679382324, 6.563111782073975, 5.768136978149414, 7.354924201965332, 4.897534370422363, 5.496021270751953, 7.622555732727051, 5.840635776519775, 3.9819939136505127, 5.178964138031006, 6.409365653991699, 7.007185459136963, 5.217965126037598, 4.2274580001831055, 3.460437774658203, 5.892084121704102, 6.671872615814209, 4.312008380889893, 3.916489601135254, 5.1720356941223145, 5.788219451904297, 6.6099772453308105, 8.093299865722656, 4.271887302398682, 4.0712809562683105, 5.130293846130371, 4.812620639801025, 5.342617988586426, 6.949527263641357, 4.023703575134277, 4.946648120880127, 7.886467933654785, 6.339351654052734, 5.967156410217285, 9.746685028076172, 4.43704891204834, 6.975742816925049, 5.968456268310547, 6.376980781555176, 4.18852424621582, 5.5012640953063965, 5.587374210357666, 5.432607173919678, 3.941983699798584, 5.322903156280518, 5.07615852355957, 4.918703556060791, 4.870357990264893, 5.592143535614014, 6.564578056335449, 4.592405796051025, 4.105345726013184, 7.506558418273926, 5.5519256591796875, 5.7933125495910645, 4.3608269691467285, 6.457507133483887, 10.028326034545898, 4.971149921417236, 4.088505268096924, 6.650249481201172, 5.000826358795166, 3.865445852279663, 5.434328079223633, 4.461536884307861, 4.977989196777344, 5.392114639282227, 4.708802700042725, 4.826531887054443, 5.328532695770264, 5.886035442352295, 4.249091148376465, 7.325138092041016, 8.420778274536133, 4.832456588745117, 5.0511250495910645, 5.2323174476623535, 5.8764328956604, 5.078305244445801, 4.86873722076416, 6.823729991912842, 5.657869338989258, 5.820645809173584, 6.9827880859375, 5.563462257385254, 6.6841139793396, 5.553593158721924, 8.594972610473633, 5.368991374969482, 6.236850738525391, 4.3676934242248535, 7.3758158683776855, 4.815007209777832, 6.325582981109619, 4.924196720123291, 6.4298481941223145, 3.8681321144104004, 4.686444282531738, 6.960890293121338, 5.203207492828369, 7.998506546020508, 3.831228733062744, 4.845139026641846, 4.252604961395264, 5.362921714782715, 6.341238975524902, 6.434725284576416, 6.211270332336426, 5.763537883758545, 3.591996669769287, 5.218696117401123, 5.444282054901123, 6.648094654083252, 4.843792915344238, 6.674583911895752, 6.569207668304443, 4.960056304931641, 8.698943138122559, 6.2795257568359375, 5.379208564758301, 6.082219123840332, 5.873574733734131, 7.059335708618164, 4.915650367736816, 7.119204998016357, 4.975472450256348, 5.41631555557251, 5.127082824707031, 4.9675469398498535, 4.984793663024902, 7.256518840789795, 4.065883636474609, 6.81244421005249, 5.435606002807617, 3.921901226043701, 4.723994731903076, 5.6094489097595215, 5.897407531738281, 4.544162273406982, 3.9051473140716553, 5.535574913024902, 5.240645885467529, 9.548104286193848, 6.702853202819824, 6.325390338897705, 5.612001419067383, 5.224496841430664, 4.478315830230713, 7.145849704742432, 4.177754878997803, 3.9254486560821533, 6.698530197143555, 4.548393249511719, 4.841246128082275, 6.47455358505249, 4.797677516937256, 6.03676176071167, 5.265818119049072, 9.032440185546875, 5.836031913757324, 6.581517219543457, 4.884466171264648, 5.417908191680908, 6.376159191131592, 5.308248996734619, 4.710636615753174, 6.87875509262085, 5.598840713500977, 6.408730983734131, 6.435343265533447, 4.41728401184082, 5.37245512008667, 6.972075939178467, 7.887876033782959, 4.195741176605225, 4.405188083648682, 8.937284469604492, 6.400877475738525, 5.119266033172607, 6.183807849884033, 4.263365745544434, 6.97926664352417, 6.592281818389893, 3.992121458053589, 6.223937034606934, 5.107494831085205, 5.68265962600708, 5.407548427581787, 3.3061013221740723, 7.8041791915893555, 6.143197059631348, 5.463844299316406, 4.8974833488464355, 3.5409903526306152, 7.007785320281982, 11.166836738586426, 6.666378498077393, 4.975729942321777, 4.327641010284424, 7.8859639167785645, 5.794493675231934, 5.602377891540527, 6.113709926605225, 4.635875701904297, 5.9120588302612305, 5.381917953491211, 6.621859550476074, 4.9615864753723145, 5.404289245605469, 4.93181037902832, 5.5644025802612305, 4.525747776031494, 4.916431427001953, 6.828763961791992, 6.602478981018066, 3.9263570308685303, 4.256958961486816, 5.061564922332764, 5.603464603424072, 5.8582024574279785, 5.264402866363525, 5.812119483947754, 6.653538227081299, 4.957103729248047, 3.728381872177124, 4.114070415496826, 5.5085577964782715, 6.679239273071289, 6.222227096557617, 5.622066020965576, 3.9252922534942627, 5.689344882965088, 4.16047477722168, 8.220654487609863, 5.612486362457275, 4.5407233238220215, 5.444851875305176, 4.819954872131348, 4.723607063293457, 5.347068786621094, 5.209199905395508, 4.003062725067139, 6.819573402404785, 5.114058017730713, 13.483453750610352, 4.0599541664123535, 4.8581223487854, 5.35324764251709, 7.003511905670166, 6.450134754180908, 5.712134838104248, 6.320497512817383, 5.048724174499512, 5.917469024658203, 4.341012001037598, 4.553154468536377, 7.0019426345825195, 6.419118881225586, 6.080375671386719, 5.139658451080322, 7.098749160766602, 6.4516730308532715, 7.7601318359375, 7.641818523406982, 4.356046676635742, 4.44087028503418, 5.210386753082275, 4.820004463195801, 4.060737609863281, 5.563511371612549, 4.847728729248047, 5.630509853363037, 5.281097888946533, 4.921994209289551, 4.3355793952941895, 3.714122772216797, 5.212831974029541, 5.461049556732178, 5.062434196472168, 4.538909435272217, 4.409420967102051, 4.052426815032959, 3.180950880050659, 6.314389228820801, 4.735016822814941, 5.103275299072266, 5.644498348236084, 4.467624187469482, 5.256933212280273, 4.77923059463501, 5.8684821128845215, 4.11031436920166, 5.281602382659912, 6.009018421173096, 4.886351585388184, 6.410782337188721, 3.8095221519470215, 5.6959919929504395, 5.0030927658081055, 6.3926262855529785, 5.2983527183532715, 4.941037178039551, 4.172575950622559, 6.802154064178467, 5.514207363128662, 6.791955947875977, 7.46196174621582, 4.3515424728393555, 4.980888843536377, 5.432982921600342, 5.237070560455322, 4.480713844299316, 4.633219242095947, 5.374973297119141, 6.417896747589111, 7.221532821655273, 5.390798568725586, 5.715476989746094, 4.896486282348633, 5.037192344665527, 6.395817279815674, 4.062927722930908, 6.646440505981445, 5.590949058532715, 5.044816493988037, 5.743856430053711, 5.480149269104004, 4.332553386688232, 5.971588134765625, 4.554636001586914, 6.324791431427002, 4.708901882171631, 4.899785995483398, 6.484497547149658, 4.690489768981934, 6.297364234924316, 4.094101905822754, 7.145374774932861, 5.652587890625, 6.54557466506958, 4.284666538238525, 5.968079090118408, 5.298973560333252, 6.1399126052856445, 5.520694732666016, 5.7780303955078125, 3.777456521987915, 4.761498928070068, 4.964026927947998, 4.897927284240723, 4.493646621704102, 9.037558555603027, 3.498986005783081, 4.472965717315674, 4.933645248413086, 4.243432521820068, 5.432306289672852, 6.150631427764893, 4.586724281311035, 6.164276599884033, 5.820185661315918, 4.536027908325195, 4.979557514190674, 9.156320571899414, 4.92293119430542, 4.0805344581604, 5.770482063293457, 5.300227165222168, 6.027963161468506, 5.070460319519043, 4.458511829376221, 4.8012590408325195, 6.292223930358887, 9.405464172363281, 3.505561351776123, 4.817250728607178, 5.559386730194092, 7.379927158355713, 4.617277145385742, 3.415250062942505, 6.5774827003479, 4.632955074310303, 5.2085652351379395, 5.3631672859191895, 6.865860939025879, 6.1891374588012695, 3.8615074157714844, 4.527386665344238, 7.069385528564453, 4.42758321762085, 5.161080360412598, 7.3694376945495605, 8.18382740020752, 4.975366115570068, 7.764530658721924, 4.38252067565918, 7.41975212097168, 7.651338577270508, 5.458767890930176, 5.763463497161865, 7.8061089515686035, 7.632701396942139, 5.632829666137695, 4.970534324645996, 9.95888900756836, 5.768163681030273, 4.624879837036133, 4.765697956085205, 4.8517866134643555, 5.1542134284973145, 5.410863876342773, 5.066503047943115, 6.060723781585693, 5.3578009605407715, 6.190036296844482, 6.666743755340576, 5.562911033630371, 8.0006103515625, 6.77481746673584, 6.243293285369873, 4.506988525390625, 3.6664352416992188, 6.393021106719971, 5.935453414916992, 4.789637565612793, 5.197983741760254, 4.711825370788574, 5.10800838470459, 7.0611572265625, 5.434595584869385, 5.265028476715088, 4.823402404785156, 5.070223808288574, 5.209002494812012, 6.15423583984375, 5.565286159515381, 5.045297622680664, 4.79868221282959, 7.677227973937988, 4.402411460876465, 4.305357456207275, 5.194305896759033, 3.986377239227295, 5.466111660003662, 7.235400199890137, 4.623812675476074, 6.026928424835205, 4.6269965171813965, 5.855003356933594, 6.712419033050537, 4.601466178894043, 7.465353012084961, 5.082119941711426, 6.416214466094971, 6.497832298278809, 4.111502170562744, 6.223653316497803, 4.368839740753174, 4.825159549713135, 5.985386848449707, 4.3474297523498535, 5.883011817932129, 5.982375621795654, 8.325874328613281, 5.791121959686279, 7.224799633026123, 4.336668968200684, 6.399385929107666, 6.791551113128662, 6.648420810699463, 4.762259006500244, 7.209417819976807, 7.271295070648193, 5.025713920593262, 5.85904598236084, 4.305496692657471, 5.2666015625, 7.079563140869141, 5.722559928894043, 7.136363983154297, 4.599996566772461, 6.506136894226074, 3.865513563156128, 7.120850086212158, 3.72727370262146, 4.8833770751953125, 5.728496074676514, 5.214585304260254, 4.057071208953857, 5.840620517730713, 6.369965076446533, 7.27946662902832, 4.882765769958496], 'mean_perplexity': 5.632317275047302}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "%cd /hpc2hdd/home/xzou428/Yuhao/HiGPT-tune-lightning/\n",
    "from pathlib import Path\n",
    "\n",
    "predictions = json.load(open(\"inference/NLP_RW/deepseek/eval_res_deepseek_chat_v2.json\"))\n",
    "predictions = [list(entry.values())[0] for entry in predictions]\n",
    "\n",
    "results = perplexity.compute(predictions=predictions, model=model, tokenizer=tokenizer, batch_size=2,\n",
    "                                output_dir = None)\n",
    "print(results)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.14 ('yuh')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "92619f85c62eb73280c07ca2268c8e47b90999a589aa097a1a08a504bd2fb2c6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
